{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# <div align=\"center\">embedding_lookup</div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "tf.logging.set_verbosity(tf.logging.ERROR) \n",
    "\n",
    "sess = tf.InteractiveSession()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[10 30 20 40]\n",
      "[10 30 20 40]\n",
      "[10 20 30]\n",
      "indices[0] = 4 is not in [0, 4)\n",
      "\t [[node embedding_lookup_2 (defined at <ipython-input-2-5e0c607b259a>:17) ]]\n"
     ]
    }
   ],
   "source": [
    "#####################################################################################\n",
    "# <codecell> 简单实例(single tensor)\n",
    "#####################################################################################\n",
    "\n",
    "params = tf.constant([10, 30, 20, 40])\n",
    "\n",
    "print(params.eval())\n",
    "\n",
    "ids = tf.constant([0, 1, 2, 3])\n",
    "print(tf.nn.embedding_lookup(params, ids).eval()) # [10 30 20 40]\n",
    "\n",
    "ids = tf.constant([0, 2, 1])\n",
    "print(tf.nn.embedding_lookup([params], ids).eval()) # [10 20 30]\n",
    "\n",
    "try:\n",
    "    ids = tf.constant([4, 2, 1, 3])\n",
    "    print(tf.nn.embedding_lookup([params], ids).eval()) # throw Exception\n",
    "except Exception as e:\n",
    "    print(e.message) # Error: 4 is not in [0, 4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "#####################################################################################\n",
    "# <codecell> 多Tensor, 每个tensor第一维包含的元素个数相同(其他维度shape必须相同)\n",
    "#####################################################################################\n",
    "\n",
    "param1 = tf.constant([[1, 101], [2, 102], [3, 103], [33, 333]])\n",
    "param2 = tf.constant([[4, 104], [5, 105], [6, 106], [66, 666]])\n",
    "param3 = tf.constant([[7, 107], [8, 108], [9, 109], [99, 999]])\n",
    "\n",
    "params = [param1, param2, param3]\n",
    "ids = tf.constant([0, 1, 2, 3, 4, 7])\n",
    "\n",
    "#\n",
    "# 分析:\n",
    "#   第一个关键数字是\"3\": params由3个tensor组成, ids索引分别从这个3个分区选取\n",
    "#   第二个关键数字是\"12\": 12 = 3(tensor个数) x 4(每个tensor的第一维个数), 即最大id+1\n",
    "#   第三个关键数字是\"4\": 4 = (11 + 1) / 3 (得出每个tensor应该含有id的最大个数)\n",
    "#                        11是ids中最大下标 = 所有tensor第一维元素个数总和\n",
    "#\n",
    "#   \"mod\": id % 3 代表落入哪个tensor中, eg: 7 % 3 = 1(第二个tensor中)\n",
    "#       tensor0包含的ids: [ 0, 3, 6, 9 ]\n",
    "#       tensor1包含的ids: [ 1, 4, 7, 10 ]\n",
    "#       tensor2包含的ids: [ 2, 5, 8, 11 ]\n",
    "#\n",
    "#                      [1, 101]  [2, 102]  [3, 103]  [33, 333]\n",
    "#                         0         3         6          9\n",
    "#                      [4, 104]  [5, 105]  [6, 106]  [66, 666]\n",
    "#                         1         4         7         10\n",
    "#                      [7, 107]  [8, 108]  [9, 109]  [99, 999]\n",
    "#                         2         5         8         11\n",
    "#\n",
    "#   \"div\": id // 4 代表落入哪个tensor中, eg: 7 // 4 = 1(第二个tensor中)\n",
    "#       tensor0包含的ids: [ 0, 1, 2, 3 ]\n",
    "#       tensor1包含的ids: [ 4, 5, 6, 7 ]\n",
    "#       tensor2包含的ids: [ 8, 9, 10, 11 ]\n",
    "#\n",
    "#                      [1, 101]  [2, 102]  [3, 103]  [33, 333]\n",
    "#                         0         1         2          3\n",
    "#                      [4, 104]  [5, 105]  [6, 106]  [66, 666]\n",
    "#                         4         5         6          7\n",
    "#                      [7, 107]  [8, 108]  [9, 109]  [99, 999]\n",
    "#                         8         9        10         11"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[  1, 101],\n",
       "       [  4, 104],\n",
       "       [  7, 107],\n",
       "       [  2, 102],\n",
       "       [  5, 105],\n",
       "       [  6, 106]], dtype=int32)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mod = tf.nn.embedding_lookup(params, ids)\n",
    "mod.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[  1, 101],\n",
       "       [  2, 102],\n",
       "       [  3, 103],\n",
       "       [ 33, 333],\n",
       "       [  4, 104],\n",
       "       [ 66, 666]], dtype=int32)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "div = tf.nn.embedding_lookup(params, ids, partition_strategy='div')\n",
    "div.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[  1, 101],\n",
       "       [  5, 105],\n",
       "       [  7, 107],\n",
       "       [  2, 102],\n",
       "       [  6, 106],\n",
       "       [  8, 108],\n",
       "       [  9, 109]], dtype=int32)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#####################################################################################\n",
    "# <codecell> 多Tensor, 每个tensor第一维包含的元素个数不同(其他维度shape必须相同)\n",
    "#####################################################################################\n",
    "\n",
    "param1 = tf.constant([[1, 101], [2, 102], [3, 103], [4, 104], [33, 333]])\n",
    "param2 = tf.constant([[5, 105], [6, 106]])\n",
    "param3 = tf.constant([[7, 107], [8, 108], [9, 109], [99, 999]])\n",
    "\n",
    "params = [param1, param2, param3]\n",
    "ids = tf.constant([0, 1, 2, 3, 4, 5, 8])\n",
    "\n",
    "# 分析:\n",
    "#     3个tensor\n",
    "#     11个元素: 9 = 5 + 2 + 4\n",
    "#     每个tensor含有id最大个数 4 = (11 + 1) / 3\n",
    "#\n",
    "#     \"mod\":\n",
    "#       tensor0包含的ids: [ 0, 3, 6, 9 ]\n",
    "#       tensor1包含的ids: [ 1, 4, 7, 10 ]\n",
    "#       tensor2包含的ids: [ 2, 5, 8, 11 ]\n",
    "#\n",
    "#                    [1, 101]  [2, 102]  [3, 103], [4, 104], [33, 333]\n",
    "#                       0         3         6         9\n",
    "#                    [5, 105]  [6, 106]    ---       ---\n",
    "#                       1         4         7         10\n",
    "#                    [7, 107]  [8, 108], [9, 109], [9, 999]\n",
    "#                       2         5         8         11\n",
    "#      注意: 如果ids中含有7 || 10, 将会报错, 该位置是空\n",
    "#\n",
    "#      \"div\":\n",
    "#       tensor0包含的ids: [ 0, 1, 2, 3 ]\n",
    "#       tensor1包含的ids: [ 4, 5, 6, 7 ]\n",
    "#       tensor2包含的ids: [ 8, 9, 10, 11 ]\n",
    "#\n",
    "#                    [1, 101]  [2, 102]  [3, 103], [4, 104], [33, 333]\n",
    "#                       0         1         2         3\n",
    "#                    [5, 105]  [6, 106]    ---       ---\n",
    "#                       4         5         6         7\n",
    "#                    [7, 107]  [8, 108], [9, 109], [99, 999]\n",
    "#                       8         9         10        11\n",
    "#       注意: 如果ids中含有 6 || 7, 将会报错, 该位置是空\n",
    "\n",
    "mod = tf.nn.embedding_lookup(params, ids)\n",
    "mod.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[  1, 101],\n",
       "       [  2, 102],\n",
       "       [  3, 103],\n",
       "       [  4, 104],\n",
       "       [  5, 105],\n",
       "       [  6, 106],\n",
       "       [  7, 107]], dtype=int32)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "div = tf.nn.embedding_lookup(params, ids, partition_strategy='div')\n",
    "div.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "indices[0] = 2 is not in [0, 2)\n",
      "indices[0] = 3 is not in [0, 2)\n",
      "\t [[node embedding_lookup_8/GatherV2_1 (defined at <ipython-input-8-aca72c34f8cb>:10) ]]\n"
     ]
    }
   ],
   "source": [
    "ids = tf.constant([7])\n",
    "\n",
    "try:\n",
    "    mod = tf.nn.embedding_lookup(params, ids)\n",
    "    print(mod.eval())\n",
    "except: # noqa: E722\n",
    "    print(\"indices[0] = 2 is not in [0, 2)\")\n",
    "\n",
    "try:\n",
    "    div = tf.nn.embedding_lookup(params, ids, partition_strategy='div')\n",
    "    print(div.eval())\n",
    "except Exception as e:\n",
    "    print(e.message) # \"indices[0] = 3 is not in [0, 2)\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
